{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "2750023fb2bc420c875b3fde2cef2843": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_d942639eecdf4f3c955a5ceabb2dd012",
              "IPY_MODEL_21aed97b90ea496dafbbfde642b8da3d",
              "IPY_MODEL_c38852bb4f82461dbc2f2ad28949e04f"
            ],
            "layout": "IPY_MODEL_36b15a47acca4d19ad19aa7a75d6adc4"
          }
        },
        "d942639eecdf4f3c955a5ceabb2dd012": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_489e34fc92e041f39771ff701ddd6969",
            "placeholder": "​",
            "style": "IPY_MODEL_00a24949f2324b8f855ec5bfdc92d434",
            "value": "Epoch 0 / 1: 100%"
          }
        },
        "21aed97b90ea496dafbbfde642b8da3d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ff1caf7de77b4b279a3b6d35669b79dd",
            "max": 219,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_4c653f14317240edbea90f9b198c10cf",
            "value": 219
          }
        },
        "c38852bb4f82461dbc2f2ad28949e04f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a5e6278ca52d47f1b638e9a94be80cbd",
            "placeholder": "​",
            "style": "IPY_MODEL_85cac2d243444b52a23f40b87d1b4023",
            "value": " 219/219 [00:44&lt;00:00,  5.12it/s]"
          }
        },
        "36b15a47acca4d19ad19aa7a75d6adc4": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "489e34fc92e041f39771ff701ddd6969": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "00a24949f2324b8f855ec5bfdc92d434": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ff1caf7de77b4b279a3b6d35669b79dd": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4c653f14317240edbea90f9b198c10cf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "a5e6278ca52d47f1b638e9a94be80cbd": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "85cac2d243444b52a23f40b87d1b4023": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Binary Classification Using the ChestX-ray14 Dataset"
      ],
      "metadata": {
        "id": "HaDNCcQJ3tD7"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 0: Install PyHealth"
      ],
      "metadata": {
        "id": "j9Zj-n54qEwL"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Clone the PyHealth fork from scratch, check out its ChestX-ray14 branch,\n",
        "# and install it in editable mode.\n",
        "# NOTE: %cd leaves the kernel inside PyHealth/, so re-running this cell without\n",
        "# restarting the runtime clones a nested copy under the existing checkout.\n",
        "!rm -rf PyHealth\n",
        "!git clone https://github.com/EricSchrock/PyHealth.git\n",
        "%cd PyHealth\n",
        "!git checkout ChestX-ray14\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install -e ."
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "collapsed": true,
        "id": "TWEAeB85p0C7",
        "outputId": "32a89b86-4c11-49ca-9c46-867dbfaf2fa7"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'PyHealth'...\n",
            "remote: Enumerating objects: 8101, done.\u001b[K\n",
            "remote: Counting objects: 100% (1761/1761), done.\u001b[K\n",
            "remote: Compressing objects: 100% (512/512), done.\u001b[K\n",
            "remote: Total 8101 (delta 1555), reused 1251 (delta 1249), pack-reused 6340 (from 2)\u001b[K\n",
            "Receiving objects: 100% (8101/8101), 113.88 MiB | 26.69 MiB/s, done.\n",
            "Resolving deltas: 100% (5242/5242), done.\n",
            "/content/PyHealth\n",
            "Branch 'ChestX-ray14' set up to track remote branch 'ChestX-ray14' from 'origin'.\n",
            "Switched to a new branch 'ChestX-ray14'\n",
            "Obtaining file:///content/PyHealth\n",
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Checking if build backend supports build_editable ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build editable ... \u001b[?25l\u001b[?25hdone\n",
            "  Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing editable metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: accelerate in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.11.0)\n",
            "Collecting mne~=1.10.0 (from pyhealth==2.0a8)\n",
            "  Downloading mne-1.10.2-py3-none-any.whl.metadata (21 kB)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (3.5)\n",
            "Collecting numpy~=1.26.4 (from pyhealth==2.0a8)\n",
            "  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.0/61.0 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting ogb>=1.3.5 (from pyhealth==2.0a8)\n",
            "  Downloading ogb-1.3.6-py3-none-any.whl.metadata (6.2 kB)\n",
            "Collecting pandarallel~=1.6.5 (from pyhealth==2.0a8)\n",
            "  Downloading pandarallel-1.6.5.tar.gz (14 kB)\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Collecting pandas~=2.3.1 (from pyhealth==2.0a8)\n",
            "  Downloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m91.2/91.2 kB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: peft in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.17.1)\n",
            "Requirement already satisfied: polars~=1.31.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.31.0)\n",
            "Requirement already satisfied: pydantic~=2.11.7 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.11.10)\n",
            "Collecting rdkit (from pyhealth==2.0a8)\n",
            "  Downloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.1 kB)\n",
            "Collecting scikit-learn~=1.7.0 (from pyhealth==2.0a8)\n",
            "  Downloading scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)\n",
            "Requirement already satisfied: torchvision in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.23.0+cu126)\n",
            "Collecting torch~=2.7.1 (from pyhealth==2.0a8)\n",
            "  Downloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (29 kB)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (4.67.1)\n",
            "Collecting transformers~=4.53.2 (from pyhealth==2.0a8)\n",
            "  Downloading transformers-4.53.3-py3-none-any.whl.metadata (40 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.9/40.9 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: urllib3~=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.5.0)\n",
            "Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (4.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.1.6)\n",
            "Requirement already satisfied: lazy-loader>=0.3 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (0.4)\n",
            "Requirement already satisfied: matplotlib>=3.7 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.10.0)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (25.0)\n",
            "Requirement already satisfied: pooch>=1.5 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.8.2)\n",
            "Requirement already satisfied: scipy>=1.11 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.16.3)\n",
            "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.12/dist-packages (from ogb>=1.3.5->pyhealth==2.0a8) (1.17.0)\n",
            "Collecting outdated>=0.2.0 (from ogb>=1.3.5->pyhealth==2.0a8)\n",
            "  Downloading outdated-0.2.2-py2.py3-none-any.whl.metadata (4.7 kB)\n",
            "Requirement already satisfied: dill>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (0.3.8)\n",
            "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (5.9.5)\n",
            "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2.9.0.post0)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n",
            "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.7.0)\n",
            "Requirement already satisfied: pydantic-core==2.33.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (2.33.2)\n",
            "Requirement already satisfied: typing-extensions>=4.12.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (4.15.0)\n",
            "Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.4.2)\n",
            "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (1.5.2)\n",
            "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (3.6.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (3.20.0)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (75.2.0)\n",
            "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.13.3)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (2025.3.0)\n",
            "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.80)\n",
            "Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch~=2.7.1->pyhealth==2.0a8)\n",
            "  Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl.metadata (1.6 kB)\n",
            "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.4.1)\n",
            "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.3.0.4)\n",
            "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (10.3.7.77)\n",
            "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.7.1.2)\n",
            "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.5.4.2)\n",
            "Collecting nvidia-cusparselt-cu12==0.6.3 (from torch~=2.7.1->pyhealth==2.0a8)\n",
            "  Downloading nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl.metadata (6.8 kB)\n",
            "Collecting nvidia-nccl-cu12==2.26.2 (from torch~=2.7.1->pyhealth==2.0a8)\n",
            "  Downloading nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.0 kB)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.85)\n",
            "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.11.1.6)\n",
            "Collecting triton==3.3.1 (from torch~=2.7.1->pyhealth==2.0a8)\n",
            "  Downloading triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.5 kB)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.36.0)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (6.0.3)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2024.11.6)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2.32.4)\n",
            "Collecting tokenizers<0.22,>=0.21 (from transformers~=4.53.2->pyhealth==2.0a8)\n",
            "  Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n",
            "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.6.2)\n",
            "Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from rdkit->pyhealth==2.0a8) (11.3.0)\n",
            "INFO: pip is looking at multiple versions of torchvision to determine which version is compatible with other requirements. This could take a while.\n",
            "Collecting torchvision (from pyhealth==2.0a8)\n",
            "  Downloading torchvision-0.24.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.9 kB)\n",
            "  Downloading torchvision-0.24.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.9 kB)\n",
            "  Downloading torchvision-0.23.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)\n",
            "  Downloading torchvision-0.22.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers~=4.53.2->pyhealth==2.0a8) (1.2.0)\n",
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.3.3)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (0.12.1)\n",
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (4.60.1)\n",
            "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.4.9)\n",
            "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (3.2.5)\n",
            "Collecting littleutils (from outdated>=0.2.0->ogb>=1.3.5->pyhealth==2.0a8)\n",
            "  Downloading littleutils-0.2.4-py3-none-any.whl.metadata (679 bytes)\n",
            "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pooch>=1.5->mne~=1.10.0->pyhealth==2.0a8) (4.5.0)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.4.4)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.11)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (2025.10.5)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch~=2.7.1->pyhealth==2.0a8) (1.3.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->mne~=1.10.0->pyhealth==2.0a8) (3.0.3)\n",
            "Downloading mne-1.10.2-py3-none-any.whl (7.4 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m100.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.0/18.0 MB\u001b[0m \u001b[31m125.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading ogb-1.3.6-py3-none-any.whl (78 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.8/78.8 kB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (12.4 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.4/12.4 MB\u001b[0m \u001b[31m145.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.5 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.5/9.5 MB\u001b[0m \u001b[31m145.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl (821.0 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m821.0/821.0 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl (571.0 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m571.0/571.0 MB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl (156.8 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m156.8/156.8 MB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (201.3 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m201.3/201.3 MB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (155.7 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m155.7/155.7 MB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading transformers-4.53.3-py3-none-any.whl (10.8 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.8/10.8 MB\u001b[0m \u001b[31m133.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl (36.2 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.2/36.2 MB\u001b[0m \u001b[31m20.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading torchvision-0.22.1-cp312-cp312-manylinux_2_28_x86_64.whl (7.5 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.5/7.5 MB\u001b[0m \u001b[31m84.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)\n",
            "Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m63.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading littleutils-0.2.4-py3-none-any.whl (8.1 kB)\n",
            "Building wheels for collected packages: pyhealth, pandarallel\n",
            "  Building editable for pyhealth (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for pyhealth: filename=pyhealth-2.0a8-py3-none-any.whl size=10674 sha256=958c7e0bd8938910e22eda0840e62272710f8cae2e42ad8531f1012a34cd222f\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-c1tiyeqt/wheels/1c/98/da/d6e74a692d0be5faeba6025d7302fd470b1ee8167b77261ad6\n",
            "  Building wheel for pandarallel (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for pandarallel: filename=pandarallel-1.6.5-py3-none-any.whl size=16674 sha256=d2ad066c2563268e9811ae2c6adb46d872232aa5ad8891689caf3aea26b89d42\n",
            "  Stored in directory: /root/.cache/pip/wheels/46/f9/0d/40c9cd74a7cb8dc8fe57e8d6c3c19e2c730449c0d3f2bf66b5\n",
            "Successfully built pyhealth pandarallel\n",
            "Installing collected packages: nvidia-cusparselt-cu12, triton, nvidia-nccl-cu12, nvidia-cudnn-cu12, numpy, littleutils, rdkit, pandas, outdated, torch, tokenizers, scikit-learn, pandarallel, transformers, torchvision, ogb, mne, pyhealth\n",
            "  Attempting uninstall: nvidia-cusparselt-cu12\n",
            "    Found existing installation: nvidia-cusparselt-cu12 0.7.1\n",
            "    Uninstalling nvidia-cusparselt-cu12-0.7.1:\n",
            "      Successfully uninstalled nvidia-cusparselt-cu12-0.7.1\n",
            "  Attempting uninstall: triton\n",
            "    Found existing installation: triton 3.4.0\n",
            "    Uninstalling triton-3.4.0:\n",
            "      Successfully uninstalled triton-3.4.0\n",
            "  Attempting uninstall: nvidia-nccl-cu12\n",
            "    Found existing installation: nvidia-nccl-cu12 2.27.3\n",
            "    Uninstalling nvidia-nccl-cu12-2.27.3:\n",
            "      Successfully uninstalled nvidia-nccl-cu12-2.27.3\n",
            "  Attempting uninstall: nvidia-cudnn-cu12\n",
            "    Found existing installation: nvidia-cudnn-cu12 9.10.2.21\n",
            "    Uninstalling nvidia-cudnn-cu12-9.10.2.21:\n",
            "      Successfully uninstalled nvidia-cudnn-cu12-9.10.2.21\n",
            "  Attempting uninstall: numpy\n",
            "    Found existing installation: numpy 2.0.2\n",
            "    Uninstalling numpy-2.0.2:\n",
            "      Successfully uninstalled numpy-2.0.2\n",
            "  Attempting uninstall: pandas\n",
            "    Found existing installation: pandas 2.2.2\n",
            "    Uninstalling pandas-2.2.2:\n",
            "      Successfully uninstalled pandas-2.2.2\n",
            "  Attempting uninstall: torch\n",
            "    Found existing installation: torch 2.8.0+cu126\n",
            "    Uninstalling torch-2.8.0+cu126:\n",
            "      Successfully uninstalled torch-2.8.0+cu126\n",
            "  Attempting uninstall: tokenizers\n",
            "    Found existing installation: tokenizers 0.22.1\n",
            "    Uninstalling tokenizers-0.22.1:\n",
            "      Successfully uninstalled tokenizers-0.22.1\n",
            "  Attempting uninstall: scikit-learn\n",
            "    Found existing installation: scikit-learn 1.6.1\n",
            "    Uninstalling scikit-learn-1.6.1:\n",
            "      Successfully uninstalled scikit-learn-1.6.1\n",
            "  Attempting uninstall: transformers\n",
            "    Found existing installation: transformers 4.57.1\n",
            "    Uninstalling transformers-4.57.1:\n",
            "      Successfully uninstalled transformers-4.57.1\n",
            "  Attempting uninstall: torchvision\n",
            "    Found existing installation: torchvision 0.23.0+cu126\n",
            "    Uninstalling torchvision-0.23.0+cu126:\n",
            "      Successfully uninstalled torchvision-0.23.0+cu126\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.3 which is incompatible.\n",
            "pytensor 2.35.1 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n",
            "opencv-contrib-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n",
            "shap 0.50.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.\n",
            "torchaudio 2.8.0+cu126 requires torch==2.8.0, but you have torch 2.7.1 which is incompatible.\n",
            "opencv-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n",
            "jax 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n",
            "opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n",
            "jaxlib 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed littleutils-0.2.4 mne-1.10.2 numpy-1.26.4 nvidia-cudnn-cu12-9.5.1.17 nvidia-cusparselt-cu12-0.6.3 nvidia-nccl-cu12-2.26.2 ogb-1.3.6 outdated-0.2.2 pandarallel-1.6.5 pandas-2.3.3 pyhealth-2.0a8 rdkit-2025.9.1 scikit-learn-1.7.2 tokenizers-0.21.4 torch-2.7.1 torchvision-0.22.1 transformers-4.53.3 triton-3.3.1\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "numpy"
                ]
              },
              "id": "3737617eb2cf402699bacea64f559c14"
            }
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 1: Load Dataset"
      ],
      "metadata": {
        "id": "rMjzPqNbscDV"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from pyhealth.datasets import ChestXray14Dataset\n",
        "\n",
        "dataset = ChestXray14Dataset(download=True, partial=True)\n",
        "dataset.stats()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "q_fTVUTrsryn",
        "outputId": "0660d909-31c6-48df-bb98-a015e48dd88d"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading ./images_01.tar.gz...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.datasets.chestxray14:Downloading ./images_01.tar.gz...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Checking MD5 checksum for ./images_01.tar.gz...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.datasets.chestxray14:Checking MD5 checksum for ./images_01.tar.gz...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Extracting ./images_01.tar.gz...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.datasets.chestxray14:Extracting ./images_01.tar.gz...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Deleting ./images_01.tar.gz...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.datasets.chestxray14:Deleting ./images_01.tar.gz...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Download complete\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.datasets.chestxray14:Download complete\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Initializing ChestX-ray14 dataset from . (dev mode: False)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.datasets.base_dataset:Initializing ChestX-ray14 dataset from . (dev mode: False)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.datasets.base_dataset:Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting global event dataframe...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.datasets.base_dataset:Collecting global event dataframe...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collected dataframe with shape: (4999, 26)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.datasets.base_dataset:Collected dataframe with shape: (4999, 26)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Dataset: ChestX-ray14\n",
            "Dev mode: False\n",
            "Number of patients: 1335\n",
            "Number of events: 4999\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 2: Define Task and Prepare Data Loaders"
      ],
      "metadata": {
        "id": "ecF9IgCb22N5"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from pyhealth.tasks import ChestXray14BinaryClassification\n",
        "\n",
        "task = ChestXray14BinaryClassification(disease=\"infiltration\")\n",
        "samples = dataset.set_task(task)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uj9ALkQGtVqF",
        "outputId": "bfb6c953-9411-40be-9ca0-060077cbad96"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Setting task ChestXray14BinaryClassification for ChestX-ray14 base dataset...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.datasets.base_dataset:Setting task ChestXray14BinaryClassification for ChestX-ray14 base dataset...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Generating samples with 1 worker(s)...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.datasets.base_dataset:Generating samples with 1 worker(s)...\n",
            "Generating samples for ChestXray14BinaryClassification with 1 worker: 100%|██████████| 1335/1335 [00:00<00:00, 1770.85it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Label label vocab: {0: 0, 1: 1}\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\n",
            "INFO:pyhealth.processors.label_processor:Label label vocab: {0: 0, 1: 1}\n",
            "Processing samples: 100%|██████████| 4999/4999 [01:22<00:00, 60.94it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Generated 4999 samples for task ChestXray14BinaryClassification\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\n",
            "INFO:pyhealth.datasets.base_dataset:Generated 4999 samples for task ChestXray14BinaryClassification\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from pyhealth.datasets import get_dataloader, split_by_sample\n",
        "\n",
        "train_dataset, val_dataset, test_dataset = split_by_sample(samples, [0.7, 0.1, 0.2])\n",
        "\n",
        "train_loader = get_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
        "val_loader = get_dataloader(val_dataset, batch_size=16, shuffle=False)\n",
        "test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False)"
      ],
      "metadata": {
        "id": "8qS3hfKX5GNo"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 3: Define Model"
      ],
      "metadata": {
        "id": "SjonWePy1r6N"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from pyhealth.models import CNN\n",
        "\n",
        "model = CNN(dataset=samples)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VydSOr8u0XWG",
        "outputId": "9f3bf251-0b5d-4457-9dfe-14f5d1beaeb5"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/content/PyHealth/pyhealth/metrics/calibration.py:122: SyntaxWarning: invalid escape sequence '\\c'\n",
            "  accuracy of 1. Thus, the ECE is :math:`\\\\frac{1}{3} \\cdot 0.49 + \\\\frac{2}{3}\\cdot 0.3=0.3633`.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Warning: No embedding created for field due to lack of compatible processor: image\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 4: Train Model"
      ],
      "metadata": {
        "id": "0jqDpKxgAu3-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from pyhealth.trainer import Trainer\n",
        "\n",
        "trainer = Trainer(model=model)\n",
        "trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, epochs=1)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "2750023fb2bc420c875b3fde2cef2843",
            "d942639eecdf4f3c955a5ceabb2dd012",
            "21aed97b90ea496dafbbfde642b8da3d",
            "c38852bb4f82461dbc2f2ad28949e04f",
            "36b15a47acca4d19ad19aa7a75d6adc4",
            "489e34fc92e041f39771ff701ddd6969",
            "00a24949f2324b8f855ec5bfdc92d434",
            "ff1caf7de77b4b279a3b6d35669b79dd",
            "4c653f14317240edbea90f9b198c10cf",
            "a5e6278ca52d47f1b638e9a94be80cbd",
            "85cac2d243444b52a23f40b87d1b4023"
          ]
        },
        "id": "-our6gpdAyGD",
        "outputId": "ac84d73f-9940-4333-8f7e-f7b9f2614da9"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "CNN(\n",
            "  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())\n",
            "  (cnn): ModuleDict(\n",
            "    (image): CNNLayer(\n",
            "      (cnn): ModuleList(\n",
            "        (0): CNNBlock(\n",
            "          (conv1): Sequential(\n",
            "            (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "            (2): ReLU()\n",
            "          )\n",
            "          (conv2): Sequential(\n",
            "            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "          )\n",
            "          (downsample): Sequential(\n",
            "            (0): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))\n",
            "            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "          )\n",
            "          (relu): ReLU()\n",
            "        )\n",
            "      )\n",
            "      (pooling): AdaptiveAvgPool2d(output_size=1)\n",
            "    )\n",
            "  )\n",
            "  (fc): Linear(in_features=128, out_features=1, bias=True)\n",
            ")\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:CNN(\n",
            "  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())\n",
            "  (cnn): ModuleDict(\n",
            "    (image): CNNLayer(\n",
            "      (cnn): ModuleList(\n",
            "        (0): CNNBlock(\n",
            "          (conv1): Sequential(\n",
            "            (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "            (2): ReLU()\n",
            "          )\n",
            "          (conv2): Sequential(\n",
            "            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "          )\n",
            "          (downsample): Sequential(\n",
            "            (0): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))\n",
            "            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "          )\n",
            "          (relu): ReLU()\n",
            "        )\n",
            "      )\n",
            "      (pooling): AdaptiveAvgPool2d(output_size=1)\n",
            "    )\n",
            "  )\n",
            "  (fc): Linear(in_features=128, out_features=1, bias=True)\n",
            ")\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Metrics: None\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:Metrics: None\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Device: cuda\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:Device: cuda\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Training:\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:Training:\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Batch size: 16\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:Batch size: 16\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Optimizer: <class 'torch.optim.adam.Adam'>\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:Optimizer: <class 'torch.optim.adam.Adam'>\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Optimizer params: {'lr': 0.001}\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:Optimizer params: {'lr': 0.001}\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Weight decay: 0.0\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:Weight decay: 0.0\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Max grad norm: None\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:Max grad norm: None\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7b10bd7e3c50>\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7b10bd7e3c50>\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Monitor: None\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:Monitor: None\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Monitor criterion: max\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:Monitor criterion: max\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epochs: 1\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:Epochs: 1\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Patience: None\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:Patience: None\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Epoch 0 / 1:   0%|          | 0/219 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "2750023fb2bc420c875b3fde2cef2843"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--- Train epoch-0, step-219 ---\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:--- Train epoch-0, step-219 ---\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loss: 0.4454\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:loss: 0.4454\n",
            "Evaluation: 100%|██████████| 32/32 [00:02<00:00, 14.52it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--- Eval epoch-0, step-219 ---\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\n",
            "INFO:pyhealth.trainer:--- Eval epoch-0, step-219 ---\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "pr_auc: 0.2287\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:pr_auc: 0.2287\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "roc_auc: 0.6607\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:roc_auc: 0.6607\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "f1: 0.0000\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:f1: 0.0000\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loss: 0.4239\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:pyhealth.trainer:loss: 0.4239\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 5: Evaluate Model"
      ],
      "metadata": {
        "id": "SxaKfhbubrS5"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "trainer.evaluate(test_loader)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6Dxfs5LKbwkq",
        "outputId": "25b9f7bd-bb9d-4702-a3af-cc4c67b02cb8"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Evaluation: 100%|██████████| 63/63 [00:04<00:00, 14.34it/s]\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'pr_auc': 0.2584620467432758,\n",
              " 'roc_auc': 0.6359064935064935,\n",
              " 'f1': 0.0,\n",
              " 'loss': 0.4518731695318979}"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ]
    }
  ]
}